{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b2748b09-6958-4cb0-abe9-7b76fdacedac",
   "metadata": {},
   "source": [
    "## Structured generation with LLM（2）：Function Calling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4df3fa8-0a6b-4c77-ac77-44e0d11071b7",
   "metadata": {},
   "source": [
    "第二期，使用[智谱AI](https://open.bigmodel.cn/)的glm-4 API进行实验"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "085f6ba4-3e70-4458-9240-68a0c9cae96c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# 本课程需要提前安装zhipuai\n",
    "# !pip3 install zhipuai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "facf43e8-a12d-407c-b443-2aebb97501d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from zhipuai import ZhipuAI\n",
    "client = ZhipuAI(api_key=\"xxx\") # 请填写自己的APIKey"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16a575ff-c9e2-4529-b5ca-a2a98ee40065",
   "metadata": {},
   "source": [
    "## Example 1: 中文翻译器\n",
    "\n",
    "效果：输入任意文本，返回{\"chinese\": 翻译结果}\n",
    "\n",
    "使用Function Calling的思路是：\n",
    "* 把你想返回的schema用json schema表示，放到tools列表里\n",
    "* 将tools传给model\n",
    "* 在prompt/system prompt中，强调模型要始终调用这个tool\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ebc288a8-4aca-47db-8d7b-76372fc67a32",
   "metadata": {},
   "outputs": [],
   "source": [
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_translate_result\",\n",
    "            \"description\": \"得到任意文本的中文翻译。\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"chinese\": {\n",
    "                        \"description\": \"中文翻译结果\",\n",
    "                        \"type\": \"string\"\n",
    "                    },\n",
    "                },\n",
    "                \"required\": [ \"chinese\" ]\n",
    "            },\n",
    "        }\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dd6d2408-4cb3-4bcf-b0cc-3b714db85f78",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CompletionMessage(content='', role='assistant', tool_calls=[CompletionMessageToolCall(id='call_20240724141243f4f0d6bdd45e4e2f9460732e254f42ec', function=Function(arguments='{\"chinese\": \"我们训练了一个基于GPT-4的模型，称为CriticGPT，用于捕捉ChatGPT代码输出的错误。我们发现，当人们从CriticGPT那里获得帮助来审查ChatGPT代码时，他们有60%的时间会优于那些没有帮助的人。我们开始将类似CriticGPT的模型集成到我们的RLHF标记流程中，为我们的训练师提供明确的AI辅助。这是朝着能够评估高级AI系统输出的方向迈出的一步，这些输出在没有更好的工具的情况下很难被人类评估。\"}', name='get_translate_result'), type='function', index=0)])\n"
     ]
    }
   ],
   "source": [
    "# 调用function calling\n",
    "system_prompt = '''给你一个“get_translate_result” tool，请每次都调用这个tool。'''\n",
    "\n",
    "query = \"We've trained a model, based on GPT-4, called CriticGPT to catch errors in ChatGPT's code output. We found that when people get help from CriticGPT to review ChatGPT code they outperform those without help 60% of the time. We are beginning the work to integrate CriticGPT-like models into our RLHF labeling pipeline, providing our trainers with explicit AI assistance. This is a step towards being able to evaluate outputs from advanced AI systems that can be difficult for people to rate without better tools.\"\n",
    "\n",
    "response = client.chat.completions.create(\n",
    "    model='glm-4-flash',\n",
    "    messages=[\n",
    "        {'role': 'system', 'content': system_prompt},\n",
    "        {'role': 'user', 'content': query},\n",
    "    ],\n",
    "    tools=tools,\n",
    "    do_sample=False,\n",
    ")\n",
    "\n",
    "print(response.choices[0].message)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3575d9b2-5afd-4d8b-a215-98dfdef1e398",
   "metadata": {},
   "source": [
    "-----成功，撒花🎉🎉-----\n",
    "\n",
    "中文翻译的结构化结果，已经被放在了`tool_calls`参数中；这个参数包括`tool_id`、`tool_name`还有`arguments`(我们想要的结构化输出）。\n",
    "\n",
    "对于我们的使用场景而言，直接取出`arguments`即可"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1a82ab0f-ee4c-49ff-9169-aa539de1ca59",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'chinese': '我们训练了一个基于GPT-4的模型，称为CriticGPT，用于捕捉ChatGPT代码输出的错误。我们发现，当人们从CriticGPT那里获得帮助来审查ChatGPT代码时，他们有60%的时间会优于那些没有帮助的人。我们开始将类似CriticGPT的模型集成到我们的RLHF标记流程中，为我们的训练师提供明确的AI辅助。这是朝着能够评估高级AI系统输出的方向迈出的一步，这些输出在没有更好的工具的情况下很难被人类评估。'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "json.loads(response.choices[0].message.tool_calls[0].function.arguments)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ef20dd0-b2fd-47b7-9aad-5e1161781661",
   "metadata": {},
   "source": [
    "## Example 2：评价解析\n",
    "\n",
    "预期效果：输入一段用户评价，得到评价属性（口味、价格等）、评价极性（正向、负向、中立）、评价词（好吃、贵等）、参考片段。\n",
    "\n",
    "schema如下\n",
    "\n",
    "```\n",
    "[\n",
    "    {\n",
    "        'aspect': 评价属性,\n",
    "        'sentiment': 评价极性,\n",
    "        'sentiment_word': 评价词,\n",
    "        'reference': 参考片段,\n",
    "    },\n",
    "]\n",
    "```\n",
    "\n",
    "注意，结果期望是**list of dict**，每个dict表示对于一个评价属性的解析结果，这一点需要体现在你定义的schema里面。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bcb04f98-6285-4350-8694-e3078a5ce338",
   "metadata": {},
   "outputs": [],
   "source": [
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_parse_result\",\n",
    "            \"description\": \"对评价进行解析的结果。\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"array_result\": {\n",
    "                        \"type\": \"array\", #注意这里\n",
    "                        \"description\": \"解析结果list。\",\n",
    "                        \"items\":{\n",
    "                            \"type\": \"object\",\n",
    "                            \"properties\":{\n",
    "                                \"aspect\": {\n",
    "                                \"description\": \"评价属性\",\n",
    "                                \"type\": \"string\"\n",
    "                                },\n",
    "                                \"sentiment_word\": {\n",
    "                                    \"description\": \"对评价属性的评价词，从原文中抽取\",\n",
    "                                    \"type\": \"string\"\n",
    "                                },\n",
    "                                \"sentiment\": {\n",
    "                                    \"description\": \"对评价属性的情感\",\n",
    "                                    \"type\": \"string\",\n",
    "                                    \"enum\": [\"positive\", \"negative\", \"neural\"]\n",
    "                                },\n",
    "                                \"reference\":{\n",
    "                                    \"description\": \"评价的原文\",\n",
    "                                    \"type\": \"string\"\n",
    "                                }\n",
    "                            },\n",
    "                            \"required\": [\"aspect\", \"sentiment_word\", \"sentiment\", \"reference\"]\n",
    "                        }\n",
    "                    }\n",
    "                },\"required\": [\"array_result\"]\n",
    "            },\n",
    "        }\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b1d9c8c7-c1d0-4dcd-a62d-40ff23a2ef4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CompletionMessage(content='', role='assistant', tool_calls=[CompletionMessageToolCall(id='call_202407241413131d6a6c7cbfa94886b756fb9e691fcdab', function=Function(arguments='{\"array_result\":{\"Items\":[{\"aspect\":\"环境\",\"reference\":\"整体来说，环境可以\",\"sentiment\":\"positive\",\"sentiment_word\":\"可以\"},{\"aspect\":\"味道\",\"reference\":\"味道的话也还不错\",\"sentiment\":\"positive\",\"sentiment_word\":\"还不错\"},{\"aspect\":\"价格\",\"reference\":\"但价格有一点小贵\",\"sentiment\":\"negative\",\"sentiment_word\":\"有一点小贵\"}]}}', name='get_parse_result'), type='function', index=0)])\n"
     ]
    }
   ],
   "source": [
    "# 调用function calling\n",
    "system_prompt = '''给你一个“get_parse_result” tool，请每次都调用这个tool。'''\n",
    "\n",
    "query =  \"整体来说，环境可以，味道的话也还不错，但价格有一点小贵。\"\n",
    "\n",
    "response = client.chat.completions.create(\n",
    "    model='glm-4', # 只有glm-4能理解这个任务，glm-4-air/airx/flash都不行\n",
    "    messages=[\n",
    "        {'role': 'system', 'content': system_prompt},\n",
    "        {'role': 'user', 'content': query},\n",
    "    ],\n",
    "    tools=tools,\n",
    "    do_sample=False,\n",
    ")\n",
    "\n",
    "print(response.choices[0].message)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ac7a3e6-5a01-49c2-bf66-59c414613d72",
   "metadata": {},
   "source": [
    "-----成功，撒花🎉-----"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e55f7289-ba71-480a-8252-0a2e60f553b9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'array_result': {'Items': [{'aspect': '环境',\n",
       "    'reference': '整体来说，环境可以',\n",
       "    'sentiment': 'positive',\n",
       "    'sentiment_word': '可以'},\n",
       "   {'aspect': '味道',\n",
       "    'reference': '味道的话也还不错',\n",
       "    'sentiment': 'positive',\n",
       "    'sentiment_word': '还不错'},\n",
       "   {'aspect': '价格',\n",
       "    'reference': '但价格有一点小贵',\n",
       "    'sentiment': 'negative',\n",
       "    'sentiment_word': '有一点小贵'}]}}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "json.loads(response.choices[0].message.tool_calls[0].function.arguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36a6af73-53df-4ff6-8b6f-c4d6286342cf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88fc9e50-5356-43a6-a6c7-89a5fb534250",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdbe2335-f51c-49c1-b4ac-05034e6278e4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "study",
   "language": "python",
   "name": "study"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
